import os
import argparse
import pickle

import numpy as np
import matplotlib.pyplot as plt
import torch
import gymnasium as gym

from config.bandit import args_bandit_rl2
# from metalearner import MetaLearner
from metalearner_general import MetaLearner
# from metalearner_stat_bernoulli import MetaLearner

from utils.evaluation import load_trained_network
from utils.evaluation import rollout_one_episode_rl2
from utils.evaluation import visualize_policy_BlockBandit_rl2
# from utils.evaluation import plot_rnn_hidden_states, get_rnn_connectivity, plot_rnn_connectivity
from utils.helpers import plot_training_curves, plot_evaluation_curves



def main():
    """Train a bandit meta-learner, save its statistics, and write a model summary.

    Steps, in order:
      1. Parse the command-line arguments defined by ``args_bandit_rl2``.
      2. Read the episode length from the Gymnasium environment named by
         ``args.env_name`` and fill in the derived/shared settings on ``args``.
      3. Build a ``MetaLearner`` and run ``train()``.
      4. Pickle the returned training and evaluation statistics into the
         logger's output folder and plot them.
      5. Write ``model_summary.txt`` describing the trained model.

    All outputs are written into ``metalearner.logger.full_output_folder``.
    """
    # -- ARGS --
    parser = args_bandit_rl2.get_parser()
    args = parser.parse_args()

    # The environment is only instantiated here to read its episode length,
    # so release it as soon as the value has been read (even if the lookup
    # fails) instead of keeping it open for the whole training run.
    env = gym.make(args.env_name)
    try:
        max_episode_steps = env.unwrapped.max_episode_steps
    finally:
        env.close()

    # -- shared parameters
    args.max_episode_steps = max_episode_steps
    # one policy update per full episode
    args.policy_num_steps_per_update = max_episode_steps
    args.time_as_state = False
    args.policy_algorithm = 'a2c'
    args.policy_net_activation_function = 'tanh'

    # feature extractor: the same embedding size is used for all three inputs
    embed_dim = 0
    args.action_embed_dim = embed_dim
    args.state_embed_dim = embed_dim
    args.reward_embed_dim = embed_dim

    # logging
    num_evals = 10
    num_saves = 5
    # Evaluation indices: -1 followed by `num_evals` log-spaced update indices
    # in [1, num_updates], truncated to int.
    # NOTE(review): presumably -1 requests an evaluation before training
    # starts -- confirm against MetaLearner.
    eval_ids_train = np.geomspace(
        1, args.num_updates, num_evals,
        endpoint=True, dtype=int).tolist()
    args.eval_ids = [-1] + eval_ids_train
    args.save_intermediate_models = True
    # NOTE(review): true division makes this a float (and a non-integer one
    # when num_updates is not a multiple of num_saves); confirm the consumer
    # of save_interval handles that as intended.
    args.save_interval = args.num_updates / num_saves

    # -- TRAINING --
    metalearner = MetaLearner(args)
    out_dir = metalearner.logger.full_output_folder

    train_stats, evaluation_stats = metalearner.train()

    # save training history
    with open(os.path.join(out_dir, 'train_stats.pickle'), 'wb') as f:
        pickle.dump(train_stats, f)
    with open(os.path.join(out_dir, 'evaluation_stats.pickle'), 'wb') as f:
        pickle.dump(evaluation_stats, f)

    # plot training and evaluation curves
    plot_training_curves(
        args, out_dir,
        train_stats['episode_returns'], train_stats['actor_losses'],
        train_stats['critic_losses'], train_stats['policy_entropies'],
        train_stats['activity_l2_loss']
    )
    plot_evaluation_curves(
        out_dir=out_dir,
        eval_epoch_ids=evaluation_stats['eval_epoch_ids'],
        empirical_return_avgs=evaluation_stats['empirical_return_avgs'],
        empirical_return_stds=evaluation_stats['empirical_return_stds'],
        num_eval_runs=args.num_eval_envs
    )

    # -- EVALUATION --
    # Bind a2crnn unconditionally so the name always exists; it is only
    # populated for the experiment labels that expose an actor-critic network.
    a2crnn = None
    if args.exp_label in ['rl2', 'noisy_rl2']:
        a2crnn = metalearner.policy.actor_critic

    trained_model_path = out_dir
    # Use the last path component as the model name. normpath strips a
    # trailing separator (which would otherwise yield an empty name) and
    # basename honours the platform's path separator.
    model_name = os.path.basename(os.path.normpath(trained_model_path))
    print(f'model_name: {model_name}')

    # document
    lines = [
        f'model: {model_name}',
        f'training_env: {args.env_name}',
        f'policy_entropy_loss_coeff: {args.policy_entropy_loss_coeff}',
        f'policy_critic_loss_coeff: {args.policy_critic_loss_coeff}',
    ]
    if args.exp_label == 'rl2':
        lines.append(f'shared_rnn: {args.shared_rnn}')
        lines.append(f'rnn_hidden_dim: {args.rnn_hidden_dim}')
        lines.append(f'{a2crnn}')

    summary_path = os.path.join(trained_model_path, 'model_summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as f:
        # one entry per line, each terminated by a newline
        f.write(''.join(f'{line}\n' for line in lines))

    plt.close('all')
    print('All training completed')

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()